/*
    This file is a part of ficus language project.
    See ficus/LICENSE for the licensing terms
*/

/*
    Full-scale lambda lifting.
    Unlike k_simple_ll, this algorithm
    moves all the functions to the top level.

    Also, unlike k_simple_ll, this algorithm is executed after all
    other K-form optimizations, including dead code elimination,
    tail-recursive call elimination, inline expansion etc.
    (But dead code elimination runs once again after it to remove
    the structures and functions generated by this step that end up unused.)

    In order to move all the functions to the top level and preserve
    semantics of the original code, we need to transform some nested functions,
    as well as some outer code that uses those functions.

    * We analyze each function and see if the function has 'free variables',
      i.e. variables that are defined outside of this function and yet are non-global.
    a)  If the function has some 'free variables',
        it needs a special satellite structure called 'closure data'
        that encapsulates all the free variables. The function itself
        is transformed, it gets an extra parameter, which is pointer to the
        closure data (this is done at C code generation step actually).
        All the accesses to free variables are replaced
        with the closure data access operations. Then, when the function
        occurs in code (if it does not occur, it's eliminated as dead code),
        a 'closure' is created, which is a pair (function_pointer, (closure_data or 'nil')).
        This pair is used instead of the original function. Here is the example:

        fun foo(n: int) {
            fun bar(m: int) = m * n
            bar
        }

        is replaced with

        fun bar(m: int, c: bar_closure_t) {
            m * c->n
        }
        fun foo(n: int) {
            make_closure(bar, bar_closure_t {n})
        }

    b)  If the function does not have any 'free variables', we may still
        need a closure, i.e. we may need to represent this function as a pair,
        because in general when we pass a function as a parameter to another function
        or store it as a value (essentially, we store a function pointer), the caller
        of that function does not know whether the called function needs free variables or not,
        so we need a consistent representation of functions that are called indirectly.
        But in this case we can have a pair ('some function', nil), i.e. we just use something like
        NULL pointer instead of a pointer to real closure. So, the following code:

        fun foo(n: int) {
            fun bar(m: int) = m*n // uses free variable 'n'
            fun baz(m: int) = m+1 // does not use any free variables
            if generate_random_number() % 2 == 0 {bar} else {baz}
        }
        val transform_f = foo(5)
        for i <- 0:10 {println(transform_f(i))}

        is transformed to:

        fun bar( m: int, c: bar_closure_t* ) = m*c->n
        fun baz( m: int, _: nil_closure_t* ) = m+1

        fun foo(n: int) =
            if generate_random_number() % 2 == 0
                {make_closure(bar, bar_closure_t {n})}
            else
                {make_closure(baz, nil)}

        val (transform_f_ptr, transform_f_cldata) = foo(5)
        for i <- 0:10 {println(transform_f_ptr(i, transform_f_cldata))}

        However, in the case (b) when we call the function directly, e.g.
        we call 'baz' as 'baz' directly, not via 'transform_f' pointer, we can
        avoid the closure creation step and just call it as, e.g., 'baz(real_arg, nil)'.

    From the above description it may seem that the algorithm is very simple,
    but there are some nuances:

    1.  The nested function with free variables may not just read some values
        declared outside, it may modify mutable values, i.e. var's.
        Or, it may read from a 'var', and yet it may call another nested function
        that may access the same 'var' and modify it. We could have stored an address
        of each var in the closure data, but that would be unsafe, because we may
        return the created closure outside of the function
        (which is a typical functional language pattern for generators, see below)
        where 'var' does not exist anymore. The robust solution for this problem is
        to convert each 'var', which is used at least once as a free variable,
        into a reference, which is allocated on heap:

        fun create_incrementor(start: int) {
            var v = start
            fun inc_me() {
                val temp = v; v += 1; temp
            }
            inc_me
        }
        val counter = create_incrementor(5)
        for i<-0:10 {print(f"{counter()} ")} // 5 6 7 8 ...

        this is converted to:

        fun inc_me( c: inc_me_closure_t* ) {
            val temp = *c->v; *c->v += 1; temp
        }
        fun create_incrementor(start: int) {
            val v = ref(start)
            make_closure(inc_me, inc_me_closure_t {v})
        }

        Note: currently this behavior is implemented in
        K_freevars.mutable_freevars_referencing function.

    2.  Besides the free variables, the nested function may also call:
        2a. itself. This is a simple case. We just call it and pass the same closure data, e.g.:
          fun bar( n: int, c: bar_closure_t* ) = if n <= 1 {1} else { ... bar(n-1, c) }
        2b. another function that needs some free variables from the outer scope
            fun foo(n: int) {
                fun bar(m: int) = baz(m+1)
                fun baz(m: int) = m*n
                (bar, baz)
            }

            in order to form the closure for 'baz', 'bar' needs to read 'n' value,
            which it does not access directly. That is, the code can be converted to:

            // option 1: dynamically created closure
            fun bar( m: int, c:bar_closure_t* ) {
                val (baz_cl_f, baz_cl_fv) = make_closure(baz, baz_closure_t {c->n})
                baz_cl_f(m+1, baz_cl_fv)
            }
            fun baz( m: int, c:baz_closure_t* ) = m*c->n
            fun foo(n: int) = (make_closure(bar, bar_closure_t {n}),
                               make_closure(baz, baz_closure_t {n}))

            or it can be converted to

            // option 2: nested closure
            fun bar( m: int, c:bar_closure_t* ) {
                val (baz_cl_f, baz_cl_fv) = c->baz_cl
                baz_cl_f(m+1, baz_cl_fv)
            }
            fun baz( m: int, c:baz_closure_t* ) = m*c->n
            fun foo(n: int) = {
                val baz_cl = make_closure(baz, baz_closure_t {n})
                val bar_cl = make_closure(bar, {baz_cl})
                (bar_cl, baz_cl)
            }

            or it can be converted to

            // option 3: shared closure
            fun bar( m: int, c:foo_nested_closure_t* ) {
                baz(m+1, c)
            }
            fun baz( m: int, c:foo_nested_closure_t* ) = m*c->n
            fun foo(n: int) = {
                val foo_nested_closure_data = foo_nested_closure_t {n}
                val bar_cl = make_closure(bar, foo_nested_closure_data)
                val baz_cl = make_closure(baz, foo_nested_closure_data)
                (bar_cl, baz_cl)
            }

        The third option in this example is the most efficient. But in general it may
        be hard to implement, because between bar() and baz() declarations there
        can be some value definitions, i.e. in general baz() may access some values that
        are computed using bar(), and then the properly initialized shared closure may be
        difficult to build.

        The second option is also efficient, because we avoid repetitive call to
        make_closure() inside bar(). However if not only 'bar' calls 'baz',
        but also 'baz' calls 'bar', it means that both closures need to reference each other,
        so we have a reference cycle and this couple of closures
        (or, in general, a cluster of closures) will never be released.

        So, for simplicity, we just implement the first, i.e. the slowest option.
        It's not a big problem though, because:
        * the language makes an active use of dynamic data structures
          (recursive variants, lists, arrays, strings, references ...) anyway,
          and so the memory allocation is used often, but it's tuned to be efficient,
          especially for small memory blocks.
        * when we get to the lambda lifting
          stage, we already have expanded many function calls inline
        * the remaining non-expanded nested functions that do not need
          free variables are called directly without creating a closure
        * when we have some critical hotspots, we can transform the critical functions and
          pass some 'free variables' as parameters in order to eliminate closure creation.
          (normally hotspot functions operate on huge arrays and/or they can be
          transformed into tail-recursive functions, so 'mutually-recursive functions'
          and 'hotspots' are the words that rarely occur in the same sentence).
        [TODO] The options 1, 2 and 3 are not mutually exclusive; for example, we can use
        the most efficient 3rd option in some easy-to-detect partial cases
        (e.g. when the nested functions go sequentially without any non-trivially
        defined values between them) and use the first option everywhere else.

    3.  In order to implement the first option of 2b above, we need to create an
        iterative algorithm to compute the extended sets of free variables for each
        function. First, for each function we find the directly accessed free variables.
        Then we check which functions we call from this function and combine their
        free variables with the directly accessible ones, and add free variables
        from functions they call etc. We continue to do so until all the sets of
        free variables for all the functions are stabilized and do not change
        on the next iteration.


    Note: It's assumed here that the processed code does not have mutable free variables. That means
    lift_all can be called only after K_freevars.mutable_freevars_referencing.
*/

from Ast import *
from K_form import *
from K_freevars import *
import K_lift_simple
import Hashmap, Hashset

/*
    Performs full-scale lambda lifting of all the K-form modules (see the description above).

    kmods - the modules to process. They are analyzed together: the set of global ids
            and the free-variable information (fv_env) are collected across all of them
            before any module is transformed.

    Returns the list of the transformed modules, in the same order as 'kmods'.
    In each module the lifted functions (together with their closure data types and
    closure constructors) are moved to km_top. A nested function definition is replaced
    in-place either with the closure creation (when the function has free variables)
    or with a no-op. Function definitions that have a wrapper (kci_wrap_f != noid)
    are kept where they are.

    Throws compile_err when a free variable is mutable (i.e. when
    K_freevars.mutable_freevars_referencing has not been run before), when a free variable
    is used before it's defined, or when the closure information of a function is inconsistent.
*/
fun lift_all(kmods: kmodule_t list)
{
    /*
        The substitution environment maps an id to (replacement_id, function_type_opt):
        * (cl_id, None) - the id (a function with a closure, or a free variable)
          is replaced with 'cl_id' (the closure, or the free variable proxy, respectively);
        * (noid, Some(f_typ)) - the id is a function that does not need a closure;
          when it's used as a value, a closure with empty closure data is created on-fly.
    */
    type ll_subst_env_t = (id_t, (id_t, ktyp_t?)) Hashmap.t
    fun empty_subst_env(size0: int): ll_subst_env_t = Hashmap.empty(size0, noid, (noid, None))

    // collect the global ids of all the modules (see K_lift_simple.update_globals)
    val globals = empty_id_hashset(256)
    for {km_top} <- kmods {
        K_lift_simple.update_globals(km_top, globals)
    }

    // per-function free variable information (fv_fvars, fv_declared_inside, fv_called_funcs ...).
    // It's computed once and is only read afterwards.
    val fv_env : fv_env_t = collect_free_vars(kmods, globals, true, false, true)

    /*
    Sometimes a recursive function with free variables mentions itself not via a direct call,
    but by passing itself to another function (directly or through a variable). E.g:
    {
        var i = 1
        fun dipper(beeper:void->void)
        {
            if(i<1000)
            {
                beeper()
            }
        }

        fun bill()
        {
            i = i * 2
            dipper(bill)
        }

        bill()
    }
    In this case the function needs an internal copy of its own closure to pass as an argument.
    Such a special case of recursion is called self-referencing.
    The set is filled by the fold_defcl_* pass below and is read when the prologue
    of each lifted function is formed.
    */
    val self_referencing_functions = empty_id_hashset(256)
    // the function (known to fv_env) whose body is currently being folded; noid at the top level
    var curr_folded_func = noid

    // an atom that refers to the currently folded function means that the function is
    // self-referencing (the callee of KExpCall is not an atom, so direct calls don't count)
    fun fold_defcl_atom_(a: atom_t, loc: loc_t, callb: k_fold_callb_t) =
        match a {
        | AtomId n => if n == curr_folded_func {self_referencing_functions.add(n)}
        | _ => {}
        }

    // types are irrelevant for this pass
    fun fold_defcl_ktyp_(t: ktyp_t, loc: loc_t, callb: k_fold_callb_t) {}

    /*
        The first pass over a module: form a closure for each function that has free variables.
        For such a function it creates the closure data type (KClosureVars), the closure
        constructor 'make_fp' and the closure data argument 'cv', and records them in kf_closure.
        At the same time it finds the self-referencing functions.
    */
    fun fold_defcl_kexp_(e: kexp_t, callb: k_fold_callb_t) {
        match e {
        | KDefFun kf =>
            val {kf_name, kf_params, kf_rt, kf_body, kf_closure, kf_flags, kf_scope, kf_loc} = *kf
            // track the current function while folding its body in order to detect self-references
            val curr_folded_func_backup = if (fv_env.mem(kf_name)) {
                    val curr_folded_func_backup = curr_folded_func
                    curr_folded_func = kf_name
                    curr_folded_func_backup
                } else { noid }
            fold_kexp(kf_body, callb)
            if (fv_env.mem(kf_name)) {curr_folded_func = curr_folded_func_backup}
            match fv_env.find_opt(kf_name) {
            | Some ll_info =>
                val fvars = ll_info.fv_fvars
                if !fvars.empty() {
                    val fvars_final = sort_freevars(fvars)
                    val m_idx = kf_name.m
                    val fcv_tn = gen_idk(m_idx, pp(kf_name) + "_closure")
                    // create a duplicate of each free variable (the closure data member);
                    // fvars_wt is accumulated in reverse order
                    val fold fvars_wt = [] for fv <- fvars_final {
                        val {kv_typ, kv_flags, kv_loc} = get_kval(fv, kf_loc)
                        if is_mutable(fv, get_idk_loc(fv, noloc)) {
                            throw compile_err(kf_loc,
                                f"free variable '{idk2str(fv, get_idk_loc(fv, noloc))}' must not be mutable on this stage. \
                                Run K.freevars.mutable_freevars_referencing before.")
                        }
                        // A function's free variable is a non-local and yet non-global variable
                        // that the function depends on. Since it's not global, it must be in
                        // some outer function w.r.t. the current function.
                        // And so it must be in the same module as the analyzed function.
                        val new_fv = dup_idk(m_idx, fv)
                        val _ = create_kdefval(new_fv, kv_typ, kv_flags, None, [], kv_loc)
                        (new_fv, kv_typ) :: fvars_wt
                    }
                    val fcv_t = ref (kdefclosurevars_t {
                            kcv_name=fcv_tn,
                            kcv_cname="",
                            kcv_freevars=fvars_wt.rev(),
                            kcv_orig_freevars=fvars_final,
                            kcv_scope=kf_scope,
                            kcv_loc=kf_loc
                            })
                    val make_args_ktyps = [:: for (_, t) <- fvars_wt {t} ]
                    val kf_typ = get_kf_typ(kf_params, kf_rt, kf_loc)
                    val cl_arg = gen_idk(m_idx, "cv")
                    val make_fp = gen_idk(m_idx, "make_fp")
                    val _ = create_kdefconstr(make_fp, make_args_ktyps.rev(), kf_typ,
                                              CtorFP(kf_name), false, [], kf_scope, kf_loc)
                    val _ = create_kdefval(cl_arg, KTypName(fcv_tn), default_val_flags(), None, [], kf_loc)

                    /*
                        kf_closure.kci_arg is the id of argument that is used
                        to pass closure data to each function.

                        Initially, kci_arg is set to 'noid' for each function.
                        Which means that the function does not use any free variables.
                        It will still have this void* fx_fv parameter, but it will not be used.
                        During this lambda lifting step we find out which functions use free variables,
                        And we give a special name to the corresponding parameter
                        (because we need to extract free variables from the closure in the function prologue).

                        kf_closure.kci_fcv_t is the name of type that represents the function closure data
                        structure (which is dynamically allocated and which stores the reference counter,
                        pointer to the destructor and the free variables themselves).

                        kf_closure.kci_make_fp is the constructor of the closure
                    */
                    val new_kf_closure = kf_closure.{kci_arg=cl_arg, kci_fcv_t=fcv_tn, kci_make_fp=make_fp}
                    set_idk_entry(fcv_tn, KClosureVars(fcv_t))
                    *kf = kf->{kf_closure=new_kf_closure, kf_flags=kf_flags.{fun_flag_uses_fv=true}}
                }
            | _ => {}
            }
        | KExpCall(f, args, (_, loc)) =>
            // the callee is not checked: a direct recursive call does not need a closure.
            // Only the arguments may contain a self-reference.
            for a <- args {fold_defcl_atom_(a,loc,callb)}
        | _ => fold_kexp(e, callb)
        }
    }

    val defcl_callb = k_fold_callb_t
    {
        kcb_fold_atom=Some(fold_defcl_atom_),
        kcb_fold_ktyp=Some(fold_defcl_ktyp_),
        kcb_fold_kexp=Some(fold_defcl_kexp_)
    }

    // we keep track of all defined values. It helps us to detect situations
    // when KForm is broken after some optimizations and a value is used
    // before it's declared.
    var defined_so_far = empty_id_hashset(256)
    // index of the module being processed; new ids are created in this module
    var curr_m_idx = -1
    // the function whose body is being transformed (used to detect direct recursive calls)
    var curr_clo = noid
    // the new top-level code of the current module, in reverse order
    var curr_top_code : kcode_t = []
    var curr_subst_env : ll_subst_env_t = empty_subst_env(256)
    // extra definitions produced by walk_atom_n_lift_all(), in reverse order;
    // walk_kexp_n_lift_all() puts them before the expression being transformed
    var curr_lift_extra_decls : kcode_t = []

    // During lambda lifting we need not only to transform the lifted functions'
    // bodies and function call expressions; we potentially need to transform various atoms
    // used in expressions, as right-hand-side values in val/var declarations etc.
    fun walk_atom_n_lift_all(a: atom_t, loc: loc_t, callb: k_callb_t) =
        match a {
        // ids with m=0 are left as-is (presumably predefined/built-in ids)
        | AtomId({m=0}) => a
        | AtomId n =>
            match curr_subst_env.find_opt(n) {
            | Some((_, Some f_typ)) =>
                /*
                    The atom is id, corresponding to a function that does not have/need a closure.
                    But, since we are here, it means that the function is not called directly,
                    but is rather used as a value, i.e. it's passed to another function,
                    returned from a function or is stored somewhere.

                    We need to construct a light-weight closure on-fly, which will be a pair
                    (function_pointer, <null pointer to closure data>). It's done with
                    KExpMkClosure with the 'noid' constructor and the empty list of free vars.

                    But here is the trick. walk_atom() hook can only return an atom.
                    Therefore we put the extra closure definition into 'curr_lift_extra_decls',
                    which is then handled in walk_kexp() hook (i.e. walk_kexp_n_lift_all()).
                */
                val temp_cl = dup_idk(curr_m_idx, n)
                val make_cl = KExpMkClosure(noid, n, [], (f_typ, loc))
                curr_lift_extra_decls = create_kdefval(temp_cl, f_typ, default_val_flags(),
                                                Some(make_cl), curr_lift_extra_decls, loc)
                AtomId(temp_cl)
            | Some((nv, _)) =>
                /*
                    There is already pre-declared closure, we just replace
                    function name with the constructed closure
                */
                AtomId(nv)
            | _ =>
                match kinfo_(n, loc) {
                | KFun (ref {kf_flags, kf_params, kf_rt, kf_closure={kci_arg}, kf_loc}) =>
                    /*
                    This case is very similar to the first one above, but in this
                    case the global function (and thus the function that does not use free variables)
                    is accessed as a value before we got to its declaration
                    and put (noid, Some(func_type)) to the current substitution environment.
                    No big deal, we can still form the closure with null closure data pointer on-fly
                    */
                    if is_constructor(kf_flags) {
                        a // type constructors are handled separately
                    } else if kci_arg == noid {
                        // double-check that the function does not use free variables.
                        // If it does then in theory we should not get here.
                        val temp_cl = dup_idk(curr_m_idx, n)
                        val f_typ = get_kf_typ(kf_params, kf_rt, kf_loc)
                        val make_cl = KExpMkClosure(noid, n, [], (f_typ, loc))
                        curr_lift_extra_decls = create_kdefval(temp_cl, f_typ, default_val_flags(),
                                                               Some(make_cl), curr_lift_extra_decls, loc)
                        AtomId(temp_cl)
                    } else {
                        throw compile_err(loc,
                            f"for the function '{idk2str(n, loc)}' there is no corresponding closure")
                    }
                | _ => a
                }
            }
        | _ => a
        }

    // pass-by processing types
    fun walk_ktyp_n_lift_all(t: ktyp_t, loc: loc_t, callb: k_callb_t) = t

    /*
        The second pass: transforms an expression. Function definitions are moved to
        curr_top_code (and replaced in-place with the closure creation or a no-op), their bodies
        get the prologue that extracts the free variables from the closure data, and the uses of
        functions/free variables are substituted according to curr_subst_env.
        The extra definitions produced by walk_atom_n_lift_all() while transforming 'e'
        are put right before the transformed 'e'.
    */
    fun walk_kexp_n_lift_all(e: kexp_t, callb: k_callb_t)
    {
        val saved_extra_decls = curr_lift_extra_decls
        curr_lift_extra_decls = []
        val e =
        match e {
        | KDefFun kf =>
            val {kf_name, kf_params, kf_rt, kf_body, kf_closure, kf_loc} = *kf
            val {kci_arg, kci_fcv_t, kci_make_fp, kci_wrap_f} = kf_closure
            /*
                Registers the substitution for the function 'kf' in curr_subst_env and adds
                to 'code' either the closure creation (when the function has free variables)
                or KExpNop. Returns the updated code.
            */
            fun create_defclosure(kf: kdeffun_t ref, code: kcode_t, loc: loc_t)
            {
                val {kf_name, kf_params, kf_rt, kf_closure={kci_make_fp=make_fp}, kf_flags, kf_loc} = *kf
                val kf_typ = get_kf_typ(kf_params, kf_rt, kf_loc)
                val (_, orig_freevars) = get_closure_freevars(kf_name, kf_loc)
                if orig_freevars == [] {
                    if !is_constructor(kf_flags) {
                        curr_subst_env.add(kf_name, (noid, Some(kf_typ)))
                    }
                    KExpNop(loc) :: code
                } else {
                    val cl_name = dup_idk(curr_m_idx, kf_name)
                    curr_subst_env.add(kf_name, (cl_name, None))
                    defined_so_far.add(cl_name)
                    val cl_args =
                    [:: for fv <- orig_freevars {
                        if !defined_so_far.mem(fv) {
                            throw compile_err(kf_loc, f"free variable '{idk2str(fv, kf_loc)}' \
                                            of '{idk2str(kf_name, kf_loc)}' is not defined yet")
                        }
                        walk_atom_n_lift_all(AtomId(fv), kf_loc, callb)
                    } ]
                    val make_cl = KExpMkClosure(make_fp, kf_name, cl_args, (kf_typ, kf_loc))
                    create_kdefval(cl_name, kf_typ, default_val_flags(), Some(make_cl), code, kf_loc)
                }
            }

            // the closure constructor and the closure data type definitions (if any)
            val def_fcv_t_n_make =
                if kci_fcv_t == noid {
                    []
                } else {
                    val kcv =
                    match kinfo_(kci_fcv_t, kf_loc) {
                        | KClosureVars kcv => kcv
                        | _ => throw compile_err(kf_loc,
                            f"closure type '{idk2str(kci_fcv_t, kf_loc)}' for '{idk2str(kf_name, kf_loc)}' \
                                information is not valid (should be KClosureVars ...)")
                        }
                    val make_kf =
                        match kinfo_(kci_make_fp, kf_loc) {
                        | KFun make_kf => make_kf
                        | _ => throw compile_err(kf_loc,
                            f"make_fp '{idk2str(kci_make_fp, kf_loc)}' for '{idk2str(kf_name, kf_loc)}' \
                                information is not valid (should be KFun ...)")
                        }
                    [:: KDefFun(make_kf), KDefClosureVars(kcv) ]
                }
            val out_e =
                if kci_wrap_f != noid {
                    KDefFun(kf)
                } else {
                    /*
                        We may produce some extra definitions for each function, such as
                        the closure constructor definition, the closure data type definition etc.
                        those extra definitions will be put into curr_top_code.

                        In this case we put the lifted function and associated definitions
                        to the 'curr_top_code', whereas the function definition itself in some nested block
                        is replaced with the actual closure creation call, because this is the best place
                        where all the free variables the function accesses should be defined.
                    */
                    val extra_code = create_defclosure(kf, [], kf_loc)
                    curr_top_code = extra_code.tl() + def_fcv_t_n_make + [:: KDefFun(kf)] + curr_top_code
                    extra_code.hd()
                }
            // form the prologue where we extract all free variables from the closure data
            val saved_dsf = defined_so_far.copy()
            val saved_clo = curr_clo
            val saved_subst_env = curr_subst_env.copy()
            // we set the current closure to handle recursive functions. If a function
            // (with free variables or not) calls itself recursively, we don't need another closure
            // we just pass the closure data parameter directly to it.
            curr_clo = kf_name
            for arg <- kf_params { defined_so_far.add(arg) }
            // the prologue is accumulated in reverse order
            val prologue =
                if kci_fcv_t == noid {
                    []
                } else {
                    val kcv =
                        match kinfo_(kci_fcv_t, kf_loc) {
                        | KClosureVars kcv => kcv
                        | _ => throw compile_err(kf_loc,
                            f"closure type '{idk2str(kci_fcv_t, kf_loc)}' for '{idk2str(kf_name, kf_loc)}' \
                              information is not valid (should be KClosureVars ...)")
                        }
                    val {kcv_freevars, kcv_orig_freevars} = *kcv
                    // for each free variable define a proxy that reads the idx-th closure data member
                    val fold prologue = [] for (fv, t)@idx <- kcv_freevars, fv_orig <- kcv_orig_freevars {
                        if !defined_so_far.mem(fv_orig) {
                            throw compile_err(kf_loc,
                                f"free variable '{idk2str(fv_orig, kf_loc)}' of function \
                                 '{idk2str(kf_name, kf_loc)}' is not defined before the function body")
                        }
                        val fv_proxy = dup_idk(curr_m_idx, fv)
                        defined_so_far.add(fv_proxy)
                        val (t, kv_flags) =
                            match kinfo_(fv_orig, kf_loc) {
                            | KVal ({kv_typ, kv_flags}) => (kv_typ, kv_flags)
                            | _ => (t, default_val_flags())
                            }
                        val e = KExpMem(kci_arg, idx, (t, kf_loc))
                        curr_subst_env.add(fv_orig, (fv_proxy, None))
                        // the proxy is a local (non-argument, non-global) value marked with val_flag_tempref
                        val new_kv_flags = kv_flags.{ val_flag_tempref=true,
                                                      val_flag_arg=false,
                                                      val_flag_global=[] }
                        create_kdefval(fv_proxy, t, new_kv_flags, Some(e), prologue, kf_loc)
                    }

                    // a self-referencing function gets an internal copy of its own closure
                    val prologue = if self_referencing_functions.mem(kf_name) {
                        val cl_name = dup_idk(curr_m_idx, kf_name)
                        curr_subst_env.add(kf_name, (cl_name, None))
                        val kf_typ = get_kf_typ(kf_params, kf_rt, kf_loc)
                        val make_cl = KExpIntrin(IntrinMakeFPbyFCV, [:: AtomId(kf_name)], (kf_typ, kf_loc)) //TODO: Rename CLV
                        create_kdefval(cl_name, kf_typ, default_val_flags(), Some(make_cl), prologue, kf_loc)
                    } else {prologue}

                    // create closures for the called functions with free variables
                    // that are declared outside of this function
                    match fv_env.find_opt(kf_name) {
                    | Some({fv_declared_inside, fv_called_funcs}) =>
                        fv_called_funcs.foldl(
                            fun (called_f, prologue) {
                                if fv_declared_inside.mem(called_f) { prologue }
                                else {
                                match kinfo_(called_f, kf_loc) {
                                | KFun called_kf =>
                                    val {kf_closure={kci_fcv_t=called_fcv_t}} = *called_kf
                                    if called_fcv_t == noid {
                                        prologue
                                    } else {
                                        create_defclosure(called_kf, prologue, kf_loc)
                                    }
                                | _ => prologue
                                }}
                            }, prologue)
                    | _ => throw compile_err(kf_loc, f"missing 'lambda lifting' information about \
                                            function '{idk2str(kf_name, kf_loc)}'")
                    }
                }
            // now process the body.
            val body_loc = get_kexp_loc(kf_body)
            val body = code2kexp(prologue.rev() + kexp2code(kf_body), body_loc)
            val body = walk_kexp_n_lift_all(body, callb)
            // restore everything. We don't know anything about locally defined values etc. anymore.
            defined_so_far = saved_dsf
            curr_clo = saved_clo
            curr_subst_env = saved_subst_env
            *kf = kf->{kf_body=body}
            out_e
        | KDefExn (ref {ke_tag}) =>
            defined_so_far.add(ke_tag)
            walk_kexp(e, callb)
        | KDefVal(n, rhs, loc) =>
            val rhs = walk_kexp_n_lift_all(rhs, callb)
            defined_so_far.add(n)
            KDefVal(n, rhs, loc)
        | KDefVariant kvar =>
            // the interfaces of the variant are preserved as-is
            val {kvar_ifaces=kvar_ifaces0} = *kvar
            val e = walk_kexp(e, callb)
            *kvar = kvar->{kvar_ifaces=kvar_ifaces0}
            e
        | KExpFor(idom_l, at_ids, body, flags, loc) =>
            val idom_l = [:: for (i, dom_i) <- idom_l {
                             val dom_i = check_n_walk_dom(dom_i, loc, callb)
                             defined_so_far.add(i)
                             (i, dom_i)
                            } ]
            for i <- at_ids { defined_so_far.add(i) }
            val body = walk_kexp_n_lift_all(body, callb)
            KExpFor(idom_l, at_ids, body, flags, loc)
        | KExpMap(e_idom_ll, body, flags, (etyp, eloc) as kctx) =>
            val e_idom_ll =
            [:: for (e, idom_l, at_ids) <- e_idom_ll {
                val e = walk_kexp_n_lift_all(e, callb)
                val fold idom_l = [] for (i, dom_i) <- idom_l {
                    val dom_i = check_n_walk_dom(dom_i, eloc, callb)
                    defined_so_far.add(i)
                    (i, dom_i) :: idom_l
                }
                for i <- at_ids { defined_so_far.add(i) }
                (e, idom_l.rev(), at_ids)
            } ]
            val body = walk_kexp_n_lift_all(body, callb)
            KExpMap(e_idom_ll, body, flags, kctx)
        | KExpMkClosure(make_fp, f, args, (typ, loc)) =>
            val args = [:: for a <- args { walk_atom_n_lift_all(a, loc, callb) } ]
            KExpMkClosure(make_fp, f, args, (typ, loc))
        | KExpCall(f, args, (_, loc) as kctx) =>
            val args = [:: for a <- args { walk_atom_n_lift_all(a, loc, callb) } ]
            if f == curr_clo {
                // direct recursive call: the closure data parameter is passed directly
                KExpCall(f, args, kctx)
            } else {
                match kinfo_(f, loc) {
                | KFun (ref {kf_closure={kci_fcv_t}}) =>
                    if kci_fcv_t == noid { KExpCall(f, args, kctx) }
                    else { KExpCall(check_n_walk_id(f, loc, callb), args, kctx) }
                | _ => KExpCall(check_n_walk_id(f, loc, callb), args, kctx)
                }
            }
        | KExpIntrin (intt, args, kctx) =>
            // We are disabling substitution of function with internal fp inside the macros, creating this fp.
            match intt{
            |IntrinMakeFPbyFCV => KExpIntrin (intt, args, kctx)
            |_ => walk_kexp(e, callb)
            }
        | _ => walk_kexp(e, callb)
        }
        val e = if curr_lift_extra_decls == [] { e }
                else { rcode2kexp(e :: curr_lift_extra_decls, get_kexp_loc(e)) }
        curr_lift_extra_decls = saved_extra_decls
        e
    }

    val walk_n_lift_all_callb = k_callb_t
    {
        kcb_atom=Some(walk_atom_n_lift_all),
        kcb_ktyp=Some(walk_ktyp_n_lift_all),
        kcb_kexp=Some(walk_kexp_n_lift_all)
    }

    // Process the modules one by one: first form the closures and find the self-referencing
    // functions (fold_defcl_kexp_), then lift the functions (walk_kexp_n_lift_all).
    // curr_top_code accumulates the new top-level code of the module in reverse order.
    [:: for km <- kmods {
        val {km_idx, km_top} = km
        curr_m_idx = km_idx
        curr_top_code = []
        self_referencing_functions.clear()
        for e <- km_top { fold_defcl_kexp_(e, defcl_callb) }
        for e <- km_top {
            val e = walk_kexp_n_lift_all(e, walk_n_lift_all_callb)
            curr_top_code = e :: curr_top_code
        }
        km.{km_top=curr_top_code.rev()}
    }]
}